import importlib.util
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm


def set_style() -> None:
    """Configure the global matplotlib look shared by every figure in this script.

    First tries to activate the ``seaborn-v0_8-whitegrid`` style sheet; if
    matplotlib cannot find it, the currently active style is kept. Afterwards
    font sizes, spine visibility and the default figure DPI are overridden via
    ``plt.rcParams`` regardless of which style is active.
    """
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except OSError:
        # matplotlib signals an unknown/unreadable style sheet with OSError.
        # The style is cosmetic only, so fall back to the current style, but
        # let any other (unexpected) error surface instead of hiding it.
        pass
    plt.rcParams.update(
        {
            "font.size": 10,
            "axes.titlesize": 12,
            "axes.labelsize": 10,
            "xtick.labelsize": 9,
            "ytick.labelsize": 9,
            "legend.fontsize": 9,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "figure.dpi": 140,
        }
    )


def load_shiftshare_module(root: Path):
    """Load ``analysis/19_iv_shiftshare_rollout_absorb.py`` under *root* as a module.

    The file name starts with a digit, so it is not a valid identifier for a
    regular ``import`` statement; it is loaded from its path instead and given
    the module name ``"shiftshare"``. The module is executed but not inserted
    into ``sys.modules``.

    Args:
        root: Project root directory that contains the ``analysis`` folder.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
    """
    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Validate before module_from_spec(): a None spec would otherwise fail
    # there with an opaque AttributeError, and an `assert` would vanish
    # entirely when Python runs with -O.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def fit_ols_absorbed(mod, df: pd.DataFrame, y: str, controls: list[str]):
    """Fit OLS of *y* on ``netusoft`` and its interaction with ``edu_high``.

    Steps, in order:
      1. Keep only the needed columns and drop rows with any missing value.
      2. Build the interaction ``netusoft_x_edu = netusoft * edu_high``.
      3. Pass outcome and regressors through ``mod.absorb_two_way`` with
         ``region`` and ``ct`` as the two fixed-effect dimensions
         (presumably this returns the columns with both sets of fixed
         effects partialled out -- confirm against the helper).
      4. Regress the transformed outcome on a constant, the controls,
         ``edu_high``, ``netusoft`` and the interaction, with standard errors
         clustered on ``region``.

    Args:
        mod: Project module providing ``absorb_two_way``.
        df: Data containing *y*, ``netusoft``, ``edu_high``, ``ct``,
            ``region`` and every column named in *controls*.
        y: Outcome column name.
        controls: Control column names.

    Returns:
        The fitted statsmodels results object.

    Raises:
        ValueError: If no complete rows remain after dropping missing values.
    """
    needed = [y, "netusoft", "edu_high", "ct", "region", *controls]
    sample = df[needed].dropna().copy()
    if sample.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")
    sample["netusoft_x_edu"] = sample["netusoft"] * sample["edu_high"]

    to_absorb = [y, "netusoft", "netusoft_x_edu", *controls, "edu_high"]
    absorbed = mod.absorb_two_way(sample, to_absorb, fe_a="region", fe_b="ct")

    regressors = controls + ["edu_high", "netusoft", "netusoft_x_edu"]
    design = sm.add_constant(absorbed[regressors], has_constant="add")
    model = sm.OLS(absorbed[y], design)
    # NOTE(review): the cluster labels come from the pre-absorption frame, so
    # this relies on absorb_two_way keeping the same rows in the same order.
    return model.fit(cov_type="cluster", cov_kwds={"groups": sample["region"]})


def combine_effects(b_low: float, b_delta: float, cov: np.ndarray) -> tuple[float, float]:
    """Return the sum of two coefficients together with its standard error.

    Args:
        b_low: Baseline coefficient.
        b_delta: Increment added to the baseline.
        cov: 2x2 covariance matrix of ``(b_low, b_delta)``; only the entries
            ``[0, 0]``, ``[1, 1]`` and ``[0, 1]`` are read.

    Returns:
        ``(b_high, se_high)`` where ``b_high = b_low + b_delta`` and
        ``se_high = sqrt(Var(b_low) + Var(b_delta) + 2 * Cov(b_low, b_delta))``.
        A negative variance (e.g. from numerical noise) is clamped to zero
        before taking the square root.
    """
    var_low = cov[0, 0]
    var_delta = cov[1, 1]
    cross = cov[0, 1]
    variance = float(var_low + var_delta + 2 * cross)
    if variance < 0.0:
        variance = 0.0
    return b_low + b_delta, float(np.sqrt(variance))


def coef_ci(b: float, se: float, z: float = 1.96) -> tuple[float, float]:
    """Return the symmetric interval ``(b - z * se, b + z * se)``.

    The arithmetic is elementwise, so NumPy arrays may be passed for *b* and
    *se* (the plotting code in this file does so), in which case two arrays
    are returned.

    Args:
        b: Point estimate(s).
        se: Standard error(s).
        z: Multiplier applied to *se*; defaults to 1.96.
    """
    half_width = z * se
    lower = b - half_width
    upper = b + half_width
    return lower, upper


def plot_effect_panel(ax, rows: list[dict], title: str):
    """Draw a horizontal coefficient plot comparing low- and high-education effects.

    Args:
        ax: Matplotlib axes to draw on.
        rows: One dict per outcome with the keys ``label``, ``b_low``,
            ``se_low``, ``b_high`` and ``se_high``.
        title: Panel title.

    Row ``i`` is drawn at ``y = i``, i.e. the first row sits at the bottom of
    the panel and the last row at the top. The low-education estimate is
    plotted slightly below the tick and the high-education estimate slightly
    above it. Error bars span ``coef_ci`` with its default multiplier.

    NOTE(review): earlier code reversed both the data arrays and the
    y-positions, which cancelled out and produced exactly this bottom-to-top
    order. If the first outcome is meant to appear at the top, invert the
    y-axis here -- confirm the intended order before changing the figure.
    """
    positions = np.arange(len(rows))
    labels = [r["label"] for r in rows]
    ax.axvline(0, color="#888888", lw=1)

    # Low-education effect and its confidence bounds.
    b_low = np.array([r["b_low"] for r in rows])
    se_low = np.array([r["se_low"] for r in rows])
    lo_low, hi_low = coef_ci(b_low, se_low)

    # High-education effect and its confidence bounds.
    b_high = np.array([r["b_high"] for r in rows])
    se_high = np.array([r["se_high"] for r in rows])
    lo_high, hi_high = coef_ci(b_high, se_high)

    # errorbar() expects distances from the point estimate, not absolute bounds.
    ax.errorbar(
        b_low,
        positions - 0.12,
        xerr=[b_low - lo_low, hi_low - b_low],
        fmt="o",
        color="#1f77b4",
        label="Low education",
    )
    ax.errorbar(
        b_high,
        positions + 0.12,
        xerr=[b_high - lo_high, hi_high - b_high],
        fmt="o",
        color="#ff7f0e",
        label="High education",
    )

    ax.set_yticks(positions)
    ax.set_yticklabels(labels, fontsize=9)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel("Estimated effect (LPM / index units)")
    ax.grid(axis="x", color="#e6e6e6", lw=0.8)


def _effect_row(outcome: str, label: str, params, cov, se_low) -> dict:
    """Build one coefficient-plot row (low/high-education effect and SE) from a fit.

    ``params`` and ``cov`` must be label-indexed by ``netusoft`` and
    ``netusoft_x_edu``; ``se_low`` is the standard error of ``netusoft``.
    """
    terms = ["netusoft", "netusoft_x_edu"]
    b_low = float(params["netusoft"])
    b_delta = float(params["netusoft_x_edu"])
    b_high, se_high = combine_effects(b_low, b_delta, cov.loc[terms, terms].to_numpy())
    return {
        "outcome": outcome,
        "label": label,
        "b_low": b_low,
        "se_low": float(se_low),
        "b_high": b_high,
        "se_high": se_high,
    }


def _save_figure(fig, fig_dir: Path, stem: str) -> None:
    """Write *fig* to ``<stem>.png`` (260 dpi) and ``<stem>.pdf`` in *fig_dir*, then close it."""
    fig.savefig(fig_dir / f"{stem}.png", dpi=260)
    fig.savefig(fig_dir / f"{stem}.pdf")
    plt.close(fig)


def _plot_first_stage_panel(ax, fs: pd.DataFrame) -> None:
    """Draw grouped bars of the two first-stage F statistics per outcome.

    *fs* needs the columns ``label``, ``F_net`` and ``F_int``.
    """
    x = np.arange(len(fs))
    width = 0.38
    ax.bar(x - width / 2, fs["F_net"], width, label="F(netusoft)", color="#4c78a8", edgecolor="none")
    ax.bar(x + width / 2, fs["F_int"], width, label="F(netusoft×edu)", color="#f58518", edgecolor="none")
    ax.set_xticks(x)
    ax.set_xticklabels(fs["label"], rotation=0, ha="center")
    ax.set_ylabel("First-stage F")
    ax.set_title("Panel A. First-stage strength (2SLS)")
    # Dashed reference line at F = 10 with a small annotation just above it.
    ax.axhline(10, color="#777777", lw=1, ls="--")
    ax.text(x=-0.55, y=10.35, s="F=10", fontsize=9, color="#555555", ha="left", va="bottom", clip_on=False)
    # Keep the reference line visible even when every F statistic is small.
    ymax = float(np.nanmax([fs["F_net"].max(), fs["F_int"].max()]))
    ax.set_ylim(0, max(20.0, ymax + 8.0))
    ax.grid(axis="y", color="#e6e6e6", lw=0.8)
    ax.legend(frameon=True, facecolor="white", framealpha=0.9, edgecolor="none", loc="upper left", ncol=1)


def _plot_placebo_panel(ax, placebo_rows: list[dict]) -> None:
    """Draw the placebo coefficient plot; the first row appears at the top."""
    ax.axvline(0, color="#888888", lw=1)
    y = np.arange(len(placebo_rows))
    b_low = np.array([r["b_low"] for r in placebo_rows])
    se_low = np.array([r["se_low"] for r in placebo_rows])
    lo_low, hi_low = coef_ci(b_low, se_low)
    b_high = np.array([r["b_high"] for r in placebo_rows])
    se_high = np.array([r["se_high"] for r in placebo_rows])
    lo_high, hi_high = coef_ci(b_high, se_high)
    ax.errorbar(
        b_low,
        y + 0.12,
        xerr=[b_low - lo_low, hi_low - b_low],
        fmt="o",
        color="#1f77b4",
        elinewidth=2,
        capsize=0,
        label="Low edu",
    )
    ax.errorbar(
        b_high,
        y - 0.12,
        xerr=[b_high - lo_high, hi_high - b_high],
        fmt="o",
        color="#ff7f0e",
        elinewidth=2,
        capsize=0,
        label="High edu",
    )
    ax.set_yticks(y)
    ax.set_yticklabels([r["label"] for r in placebo_rows], fontsize=10)
    ax.invert_yaxis()
    ax.set_xlabel("Estimated effect")
    ax.set_title("Panel B. Placebo outcomes (2SLS)")
    ax.grid(axis="x", color="#e6e6e6", lw=0.8)
    # NOTE(review): fixed x-range; intervals wider than this are clipped --
    # confirm it still fits the current estimates.
    ax.set_xlim(-0.45, 0.35)
    ax.legend(frameon=True, facecolor="white", framealpha=0.9, edgecolor="none", loc="lower right", ncol=1)


def main() -> None:
    """Estimate the models and write Figures 1 and 2 to ``paper_joc/figures``.

    Figure 1 compares absorbed-FE OLS with the rollout-IV 2SLS estimates for
    each outcome (low- vs high-education effects). Figure 2 shows first-stage
    F statistics per outcome (Panel A) and 2SLS estimates for placebo
    outcomes (Panel B). Both figures are saved as PNG and PDF.
    """
    root = Path(__file__).resolve().parents[1]
    fig_dir = root / "paper_joc" / "figures"
    fig_dir.mkdir(parents=True, exist_ok=True)

    set_style()

    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)

    outcomes = [
        ("y_vote", "Vote (yes/no)"),
        ("y_part_index", "Participation index (non-electoral)"),
        ("y_sgnptit", "Signed petition"),
        ("y_contplt", "Contacted politician/official"),
        ("y_badge", "Worn campaign badge/sticker"),
        ("y_bctprd", "Boycotted products"),
    ]
    short_outcome_labels = {
        "y_vote": "Vote",
        "y_part_index": "Part. index",
        "y_sgnptit": "Petition",
        "y_contplt": "Contact",
        "y_badge": "Badge",
        "y_bctprd": "Boycott",
    }
    controls = ["agea", "gndr", "hinctnta"]

    # Figure 1: OLS vs 2SLS effect sizes (low/high edu)
    ols_rows = []
    iv_rows = []
    fs_rows = []
    for y, label in outcomes:
        ols = fit_ols_absorbed(mod, df, y, controls=controls)
        ols_rows.append(_effect_row(y, label, ols.params, ols.cov_params(), ols.bse["netusoft"]))

        iv = mod.fit_iv_interaction_absorb(df, y, controls=controls)
        iv_rows.append(_effect_row(y, label, iv.params, iv.cov, iv.std_errors["netusoft"]))

        # First-stage F statistics of both endogenous regressors, reused in Figure 2.
        diag = iv.first_stage.diagnostics
        fs_rows.append(
            {
                "label": short_outcome_labels.get(y, label),
                "F_net": float(diag.loc["netusoft", "f.stat"]),
                "F_int": float(diag.loc["netusoft_x_edu", "f.stat"]),
            }
        )

    fig1, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
    plot_effect_panel(axes[0], ols_rows, "Panel A. Absorbed-FE OLS")
    plot_effect_panel(axes[1], iv_rows, "Panel B. Rollout-IV 2SLS (shift-share)")
    # Only Panel A carries a legend; Panel B uses the same colours, so no
    # legend is created there.
    axes[0].legend(loc="lower right", fontsize=9, frameon=False)
    fig1.tight_layout()
    _save_figure(fig1, fig_dir, "fig1_ols_vs_iv_panels")

    # Figure 2 (main paper): First stage + placebo (keep randomization inference for supplement Fig S1)
    # Placebos (IV): gender and born-in-country
    placebo_outcomes = [
        ("gndr", "Female"),
        ("brncntr_bin", "Native-born"),
    ]
    placebo_rows = []
    for y, label in placebo_outcomes:
        if y not in df.columns:
            print(f"Skipping placebo outcome {y!r}: column not in dataset")
            continue
        # A placebo outcome must not also appear among its own controls.
        placebo_controls = [c for c in controls if c != y]
        iv = mod.fit_iv_interaction_absorb(df, y, controls=placebo_controls)
        placebo_rows.append(_effect_row(y, label, iv.params, iv.cov, iv.std_errors["netusoft"]))

    fig2, axes = plt.subplots(
        1,
        2,
        figsize=(12.5, 4.6),
        gridspec_kw={"width_ratios": [1.05, 0.95]},
    )
    _plot_first_stage_panel(axes[0], pd.DataFrame(fs_rows))
    _plot_placebo_panel(axes[1], placebo_rows)
    fig2.tight_layout(pad=1.2, w_pad=2.0)
    _save_figure(fig2, fig_dir, "fig2_firststage_placebo_panels")

    print(f"Wrote figures to: {fig_dir}")


# Generate the figures only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
